import numpy as np
from _code_ import kmeans_cost_label, algo1, k_means_cost, detAlg, samplingResult, hard_noisy_oracle, unpickle
from sklearn.cluster import KMeans
import random
# import csv
import math
from sklearn.datasets import load_digits
from sklearn.cluster import kmeans_plusplus as kpp
# from sklearn.cluster import KMeans
# from sklearn.neighbors import NearestNeighbors
# from sklearn.datasets import make_blobs
# from heapq import nlargest, nsmallest
# from sklearn.neighbors import KDTree
from sklearn.neighbors import BallTree
# from sklearn.metrics.pairwise import pairwise_distances
# from scipy.spatial.distance import cdist
# from sklearn.metrics.pairwise import euclidean_distances
# import maxtree
# from pyinstrument import Profiler
# from sklearn.neighbors import NearestNeighbors
from numba import jit


import time
# import heapq
from line_profiler import LineProfiler

# import argparse
import pandas as pd


@jit(nopython=True)
def find_minimum(dim_j_points, sample_id, n_neibor):
    """Return the mean of the tightest sampled window of sorted 1-D values.

    The values in ``dim_j_points`` are sorted. For every start index ``l`` in
    ``sample_id`` the contiguous window ``sorted[l:l + n_neibor]`` is scored
    by its sum of squared deviations from its own mean. The mean of the
    lowest-scoring window is returned.

    Parameters
    ----------
    dim_j_points : 1-D numeric array
        Coordinates of one cluster along a single dimension.
    sample_id : 1-D integer array
        Candidate window start positions (indices into the sorted values).
    n_neibor : int
        Requested window length.

    Returns
    -------
    float
        Mean of the best window, or ``-1.0`` when no candidate window holds
        any value (e.g. ``sample_id`` is empty).
    """
    sorted_vals = np.sort(dim_j_points)
    squares = sorted_vals ** 2
    # No fixed cap: any finite cost must beat "no window chosen yet".
    best_cost = np.inf
    best_mean = -1.0
    for t in range(len(sample_id)):
        start = sample_id[t]
        window = sorted_vals[start:start + n_neibor]
        width = window.shape[0]
        if width == 0:
            continue  # an empty window has no mean (and would divide by zero)
        window_sum = np.sum(window)
        # Sum of squared deviations: sum(x^2) - (sum x)^2 / m.
        cost = np.sum(squares[start:start + n_neibor]) - window_sum ** 2 / width
        if cost < best_cost:
            best_cost = cost
            # Divide by the real window length so the returned mean belongs to
            # the window that was scored, even if it was cut off at the end.
            best_mean = window_sum / width
    return best_mean

def Ours(points, oracle_labels, k, p_ours):
    """Estimate ``k`` centers coordinate by coordinate from oracle labels.

    For every oracle cluster ``i`` and every dimension ``j``,
    ``find_minimum`` sorts the cluster's ``j``-th coordinates and scores a
    few random windows of ``floor((1 - p_ours) * |cluster i|)`` consecutive
    values by their sum of squared deviations. The mean of the tightest
    window becomes ``centers[i, j]``. The same random window starts are
    reused for every dimension of a cluster.

    Parameters
    ----------
    points : ndarray, shape (n, d)
        Data to cluster.
    oracle_labels : ndarray, shape (n,)
        Predicted (possibly noisy) label of each point. Only labels
        ``0 .. k-1`` are looked at.
    k : int
        Number of centers.
    p_ours : float
        Fraction of each cluster's values trimmed away per dimension.

    Returns
    -------
    ndarray, shape (k, d)
        Estimated centers. A label with no points keeps an all-zero row.
    """
    d = points.shape[1]
    centers = np.zeros((k, d))
    n_windows = 2  # random window starts tried per cluster
    for i in range(k):
        # float64 so squaring inside find_minimum cannot wrap on integer input.
        points_i = points[np.where(oracle_labels == i)[0]].astype(np.float64)
        size_i = points_i.shape[0]
        if size_i == 0:
            continue  # nothing to estimate from; keep the zero center
        # Window length, kept in [1, size_i] so every window is non-empty.
        n_neibor = min(max(math.floor((1 - p_ours) * size_i), 1), size_i)
        # Full windows start at 0 .. size_i - n_neibor (inclusive).
        n_starts = size_i - n_neibor + 1
        # Explicit int64: numba needs integer indices.
        sample_id1 = np.array(
            random.sample(range(n_starts), min(n_windows, n_starts)),
            dtype=np.int64,
        )
        for j in range(d):
            centers[i, j] = find_minimum(points_i[:, j], sample_id1, n_neibor)
    return centers


@jit(nopython=True)
def find_minimum1(dim_j_points, dim_j_temp, omega_j, weights, sample_id, outliers):
    """Select the candidate point(s) with the lowest trimmed cost.

    Candidate ``t`` is the row ``dim_j_points[sample_id[t]]``. Its cost is the
    sum of the ``omega_j.shape[0] - outliers`` smallest squared Euclidean
    distances from the candidate to the rows of ``omega_j``. In other words,
    the ``outliers`` farthest rows are ignored, but at least one distance is
    always kept.

    The candidates are split into ``lam`` consecutive groups. Currently
    ``lam = 1``, so there is a single group holding every candidate. The best
    candidate of group ``g`` is copied into row ``g`` of ``dim_j_temp``.

    Parameters
    ----------
    dim_j_points : 2-D float array
        Rows are the points candidates are drawn from.
    dim_j_temp : 2-D array with at least ``lam`` rows
        Output buffer; its first ``lam`` rows are overwritten in place.
    omega_j : 2-D array
        Sample of points the candidates are scored against.
    weights : array
        Unused; kept for call compatibility.
    sample_id : 1-D integer array
        Candidate row indices into ``dim_j_points``.
    outliers : int
        Number of largest distances dropped from each cost.

    Returns
    -------
    The view ``dim_j_temp[0:lam]`` holding the selected candidate(s). A group
    with no candidates leaves its row unchanged.
    """
    lam = 1
    n_samples = sample_id.shape[0]
    out = dim_j_temp[0:lam]
    group_size = max(1, n_samples // lam)
    n_keep = max(1, omega_j.shape[0] - outliers)
    for g in range(lam):
        lo = g * group_size
        # The last group also absorbs any remainder of the division.
        hi = n_samples if g == lam - 1 else min(n_samples, lo + group_size)
        best_cost = np.inf
        best_idx = -1
        # Score every candidate of the group *before* recording the winner.
        for t in range(lo, hi):
            idx = sample_id[t]
            dis = np.sum((omega_j - dim_j_points[idx]) ** 2, axis=1)
            cost = np.sum(np.sort(dis)[:n_keep])
            if cost < best_cost:
                best_cost = cost
                best_idx = idx
        if best_idx >= 0:
            out[g] = dim_j_points[best_idx]
    return out


def Ours1(points, oracle_labels, k, p_ours):
    """Estimate ``k`` centers by sampled candidate search plus a trimmed mean.

    For every oracle cluster ``i``:

    1. Draw a uniform sample ``omega_j`` of its points. The sample size comes
       from a log-based heuristic, capped at ``|cluster| / 20``, raised to at
       least 2, and never more than ``|cluster|``.
    2. Draw up to 10 candidate points from the cluster. ``find_minimum1``
       returns the candidate(s) with the lowest cost against ``omega_j``,
       where the cost ignores the ``max(1, floor(1.3 * p_ours * size))``
       farthest sample points.
    3. Refine each returned candidate to the mean of the
       ``max(1, floor((1 - 3 * p_ours) * |cluster|))`` cluster points nearest
       to it. Keep the refinement with the smallest within-set SSE.

    Parameters
    ----------
    points : ndarray, shape (n, d)
        Data to cluster.
    oracle_labels : ndarray, shape (n,)
        Predicted (possibly noisy) label of each point. Only labels
        ``0 .. k-1`` are looked at.
    k : int
        Number of centers.
    p_ours : float
        Assumed corruption fraction controlling both trimming steps.

    Returns
    -------
    ndarray, shape (k, d)
        Estimated centers. A label with no points keeps an all-zero row.
    """
    n, d = points.shape
    centers = np.zeros((k, d))
    epsilon = 1
    n_candidates = 10  # candidate centers drawn per cluster
    for i in range(k):
        # float64 so squared distances cannot wrap on integer input.
        points_i = points[np.where(oracle_labels == i)[0]].astype(np.float64)
        size_i = points_i.shape[0]
        if size_i == 0:
            continue  # keep the zero center; the size heuristic needs size_i > 0
        # Number of nearest points averaged in the refinement step (>= 1).
        n_keep = max(math.floor((1 - 3 * p_ours) * size_i), 1)

        sample_size = (
            math.log10((size_i ** 3) * d * (math.log10(n * 1E4 / (epsilon ** 2))) ** 3)
            * math.log10(size_i * 1E4) / (epsilon ** 4)
        )
        sample_size = min(int(sample_size), int(size_i / 20))
        sample_size = max(sample_size, 2)
        # random.sample cannot draw more items than the cluster holds.
        sample_size = min(sample_size, size_i)
        weights = np.ones(sample_size) * size_i / sample_size

        outliers = max(math.floor(p_ours * 1.3 * math.ceil(sample_size)), 1)

        omega_j = points_i[random.sample(range(size_i), sample_size)]
        sample_id = np.array(
            random.sample(range(size_i), min(n_candidates, size_i)),
            dtype=np.int64,
        )
        # find_minimum1 writes its result into the second array, hence the copy.
        candidates = find_minimum1(points_i.copy(), points_i.copy(), omega_j,
                                   weights, sample_id, outliers)

        best_cost = np.inf
        best_center = None
        for candidate in candidates:
            dis = np.sum((points_i - candidate) ** 2, axis=1)
            if n_keep < size_i:
                nearest_ids = np.argpartition(dis, n_keep)[:n_keep]
            else:
                # argpartition requires kth < size_i; keeping all points is exact.
                nearest_ids = np.arange(size_i)
            nearest_points = points_i[nearest_ids]
            center = np.average(nearest_points, axis=0)
            # SSE of nearest_points around their mean: sum(x^2) - m * |mean|^2.
            cost = np.sum(nearest_points ** 2) - nearest_points.shape[0] * np.sum(center ** 2)
            if cost < best_cost:
                best_cost = cost
                best_center = center
        if best_center is not None:
            centers[i] = best_center
    return centers


# Columns of the per-(dataset, k, alpha, method) summary table.
_RESULT_COLUMNS = ['dataset', 'k', 'alpha', 'method', 'cost', 'cost_dev', 'time', 'time_dev']

# Benchmark sets loaded whole (no shuffle, portion ignored). Their ground-truth
# center files (e.g. 's1-cb.txt') are not read: the reference labels and the
# reported OPT cost both come from scikit-learn KMeans in the main loop.
_WHOLE_FILE_DATASETS = {
    'S1': 's1.txt',
    'S2': 's2.txt',
    'S3': 's3.txt',
    'S4': 's4.txt',
    'A1': 'a1.txt',
    'A2': 'a2.txt',
    'A3': 'a3.txt',
    'birch1': 'birch1.txt',
    'birch2': 'birch2.txt',
    'birch3': 'birch3.txt',
    'unbalance': 'unbalance.txt',
}


def _take_last(data, portion):
    """Return the last ``int(len(data) * portion)`` rows of ``data``.

    Raises ValueError when that count is below 1. Without this check,
    ``data[-0:]`` would silently return every row.
    """
    n_keep = int(len(data) * portion)
    if n_keep < 1:
        raise ValueError('portion {} of {} rows selects no data'.format(portion, len(data)))
    return data[max(len(data) - n_keep, 0):]


def _load_test_set(dataset, portion):
    """Load ``dataset`` and return the points to cluster.

    ``portion`` is the fraction of rows kept. For the shuffled datasets the
    rows are shuffled in place first; mnist is not shuffled, and the
    whole-file benchmark sets ignore ``portion``. Raises ValueError for an
    unknown dataset name.
    """
    if dataset in _WHOLE_FILE_DATASETS:
        return np.loadtxt(_WHOLE_FILE_DATASETS[dataset])
    if dataset == 'mnist':
        return _take_last(load_digits().data, portion)
    if dataset == 'cifar10':
        data = unpickle("test1.dat")[b'data']
    elif dataset == 'phy':
        data = np.loadtxt("phy.dat")
    elif dataset == 'kdd':
        data = np.loadtxt("kdd.txt", delimiter=",", usecols=(
            0, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
            34, 35, 36, 37, 38, 39, 40))
        data = data[:, :-1]
    elif dataset == "SUSY":
        data = np.loadtxt('data/SUSY.csv', usecols=tuple(range(1, 19)), delimiter=",")
        data = data[:, :-1]
    elif dataset == 'HIGGS':
        data = np.loadtxt('HIGGS.csv', delimiter=",")
        data = data[:, 1:]
    else:
        raise ValueError('unknown dataset: {!r}'.format(dataset))
    np.random.shuffle(data)
    return _take_last(data, portion)


def _summary_row(dataset, k, err, method, costs, times):
    """Build one summary-table row: mean/std of ``costs`` and of ``times``."""
    return [dataset, k, err, method,
            np.average(costs), np.std(costs), np.average(times), np.std(times)]


if __name__ == '__main__':
    # Benchmark: for every dataset, corruption rate `err` and k, corrupt the
    # labels of a KMeans reference clustering and compare how well each
    # robust method recovers low-cost centers from the noisy labels.
    # Dataset options: 'cifar10', 'phy', 'mnist', 'S1'..'S4', 'A1'..'A3',
    # 'birch1'..'birch3', 'unbalance', 'kdd', 'SUSY', 'HIGGS'.
    data_name = ['mnist']
    err_range = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    nTrials = 15   # trimming parameters swept per method (best cost kept)
    nIters = 5     # repetitions averaged in the report
    k_range = [20]
    nPortion = 1   # fraction of each shuffled dataset to use

    rows = []
    for dataset in data_name:
        print('loading data')
        test = _load_test_set(dataset, nPortion)

        for err in err_range:
            for k in k_range:
                print("Dataset size:", len(test), "k:", k, "error:", err)

                # Reference clustering: its labels are corrupted into the noisy
                # oracle and its cost is reported as "Optimal"/OPT.
                true_labels_10 = KMeans(n_clusters=k).fit(test).labels_

                print('Predictor: {} corruption (avg over {} trials)'.format(err, nIters))

                # (report name, fitting function, trimming parameters to sweep)
                methods = [
                    ('Algo1', algo1, np.linspace(.01, 0.5, nTrials)),
                    ('Det', detAlg, np.linspace(.01, .5, nTrials)),
                    ('Ours', Ours, np.linspace(.01, .5, nTrials)),
                    ('Ours1', Ours1, np.linspace(.01, .3, nTrials)),
                ]

                cost_opt = kmeans_cost_label(test, true_labels_10, k)[1]
                noisy_orc_labels = hard_noisy_oracle(test, true_labels_10, err)

                baseline_10, baseline_time = [], []
                costs = {name: [] for name, _, _ in methods}
                times = {name: [] for name, _, _ in methods}

                for _ in range(nIters):
                    # Time the k-means++ seeding that produces the baseline cost.
                    start = time.time()
                    kpp_centers = kpp(test, k)[0]
                    baseline_time.append(time.time() - start)
                    baseline_10.append(k_means_cost(test, kpp_centers)[1])

                    lowest = {name: float('inf') for name, _, _ in methods}
                    elapsed = {name: 0 for name, _, _ in methods}
                    for trial in range(nTrials):
                        fitted = {}
                        for name, fit, pvals in methods:
                            start = time.time()
                            fitted[name] = fit(test, noisy_orc_labels.copy(), k, pvals[trial])
                            elapsed[name] += time.time() - start
                        for name, centers in fitted.items():
                            lowest[name] = min(lowest[name], k_means_cost(test, centers)[1])
                    for name, _, _ in methods:
                        costs[name].append(lowest[name])
                        times[name].append(elapsed[name])

                print('kmeans++:', np.average(baseline_10), np.std(baseline_10), "Time",
                      np.average(baseline_time), np.std(baseline_time))
                for name, _, _ in methods:
                    print(name + ':', np.average(costs[name]), np.std(costs[name]), "Time",
                          np.average(times[name]), np.std(times[name]))
                print('Optimal', cost_opt)

                rows.append(_summary_row(dataset, k, err, "kmeans++", baseline_10, baseline_time))
                for name, _, _ in methods:
                    rows.append(_summary_row(dataset, k, err, name, costs[name], times[name]))
                rows.append([dataset, k, err, "OPT", cost_opt, 0, 0, 0])

    # Build the table once (repeated pd.concat is quadratic). It is kept in
    # memory only; call result.to_csv(...) to persist it.
    result = pd.DataFrame(rows, columns=_RESULT_COLUMNS)